package com.jstarcraft.ai.jsat.classifiers.linear;

import static java.lang.Math.abs;
import static java.lang.Math.exp;
import static java.lang.Math.log;
import static java.lang.Math.max;
import static java.lang.Math.signum;

import java.util.Arrays;
import java.util.Random;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.exceptions.UntrainedModelException;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.IndexValue;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.regression.RegressionDataSet;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

/**
 * Implements the iterative and single threaded stochastic solver for
 * L<sub>1</sub> regularized linear regression problems SMIDAS (Stochastic
 * Mirror Descent Algorithm mAde Sparse). It performs very well when the number
 * of features is large relative to or greater than the number of data points.
 * It also works decently on smaller sparse data sets. <br>
 * Using the squared loss is equivalent to LASSO regression, and the LOG loss is
 * equivalent to logistic regression. <br>
 * <br>
 * Note: This implementation requires all feature values to be in the range [-1,
 * 1]. By default scaling is performed to [0,1] to preserve sparseness. If your
 * data is dense or has negative and positive feature values, scaling to [-1, 1]
 * may perform better. See {@link #setReScale(boolean) }<br>
 * <br>
 * See:<br>
 * <a href="http://eprints.pascal-network.org/archive/00005418/">Shalev-Shwartz,
 * S.,&amp;Tewari, A. (2009). <i>Stochastic Methods for
 * L<sub>1</sub>-regularized Loss Minimization</i>. 26th International
 * Conference on Machine Learning (Vol. 12, pp. 929–936).</a>
 * 
 * @author Edward Raff
 */
public class SMIDAS extends StochasticSTLinearL1 {

    private static final long serialVersionUID = -4888083541600164597L;
    private double eta;

    /**
     * Creates a new SMIDAS learner
     * 
     * @param eta the learning rate for each iteration
     */
    public SMIDAS(double eta) {
        this(eta, DEFAULT_EPOCHS, DEFAULT_REG, DEFAULT_LOSS);
    }

    /**
     * Creates a new SMIDAS learner
     * 
     * @param eta    the learning rate for each iteration
     * @param epochs the number of learning iterations
     * @param lambda the regularization penalty
     * @param loss   the loss function to use
     */
    public SMIDAS(double eta, int epochs, double lambda, Loss loss) {
        this(eta, epochs, lambda, loss, true);
    }

    /**
     * Creates a new SMIDAS learner
     * 
     * @param eta     the learning rate for each iteration
     * @param epochs  the number of learning iterations
     * @param lambda  the regularization penalty
     * @param loss    the loss function to use
     * @param reScale whether or not to rescale the feature values
     */
    public SMIDAS(double eta, int epochs, double lambda, Loss loss, boolean reScale) {
        setEta(eta);
        setEpochs(epochs);
        setLambda(lambda);
        setLoss(loss);
        setReScale(reScale);
    }

    /**
     * Sets the learning rate used during training
     * 
     * @param eta the learning rate to use
     */
    public void setEta(double eta) {
        if (Double.isNaN(eta) || Double.isInfinite(eta) || eta <= 0)
            throw new ArithmeticException("convergence parameter must be a positive value");
        this.eta = eta;
    }

    /**
     * Returns the current learning rate used during training
     * 
     * @return the learning rate in use
     */
    public double getEta() {
        return eta;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        if (w == null)
            throw new UntrainedModelException("Model has not been trained");
        Vec x = data.getNumericalValues();
        return loss.classify(wDot(x));
    }

    @Override
    public double regress(DataPoint data) {
        if (w == null)
            throw new UntrainedModelException("Model has not been trained");
        Vec x = data.getNumericalValues();
        return loss.regress(wDot(x));
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getNumNumericalVars() < 3)
            throw new FailedToFitException("SMIDAS requires at least 3 features");
        else if (dataSet.getClassSize() != 2)
            throw new FailedToFitException("SMIDAS only supports binary classification problems");
        Vec[] x = setUpVecs(dataSet);

        Vec obvMinV = DenseVector.toDenseVec(obvMin);
        Vec obvMaxV = DenseVector.toDenseVec(obvMax);
        Vec multitpliers = new DenseVector(obvMaxV.length());
        multitpliers.mutableAdd(maxScaled - minScaled);
        multitpliers.mutablePairwiseDivide(obvMaxV.subtract(obvMinV));

        boolean allZeroMins = true;
        for (double min : obvMin)
            if (min != 0)
                allZeroMins = false;
        double[] target = new double[x.length];
        for (int i = 0; i < dataSet.size(); i++) {
            // Copy and scale each value
            if (allZeroMins && minScaled == 0.0) {
                x[i].mutablePairwiseMultiply(multitpliers);
            } else// destroy all sparsity and our dreams
            {
                x[i] = x[i].subtract(obvMinV);
                x[i].mutablePairwiseMultiply(multitpliers);
                x[i].mutableAdd(minScaled);
            }
            target[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        train(x, target);
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        if (dataSet.getNumNumericalVars() < 3)
            throw new FailedToFitException("SMIDAS requires at least 3 features");
        Vec[] x = setUpVecs(dataSet);

        Vec obvMinV = DenseVector.toDenseVec(obvMin);
        Vec obvMaxV = DenseVector.toDenseVec(obvMax);
        Vec multitpliers = new DenseVector(obvMaxV.length());
        multitpliers.mutableAdd(maxScaled - minScaled);
        multitpliers.mutablePairwiseDivide(obvMaxV.subtract(obvMinV));

        boolean allZeroMins = true;
        for (double min : obvMin)
            if (min != 0)
                allZeroMins = false;
        double[] target = new double[x.length];
        for (int i = 0; i < dataSet.size(); i++) {
            if (allZeroMins && minScaled == 0.0) {
                x[i].mutablePairwiseMultiply(multitpliers);
            } else {
                // Copy and scale each value
                x[i] = x[i].subtract(obvMinV);
                x[i].mutablePairwiseMultiply(multitpliers);
                x[i].mutableAdd(minScaled);
            }
            target[i] = dataSet.getTargetValue(i);
        }

        train(x, target);
    }

    private void train(Vec[] x, double[] y) {
        final int m = x.length;
        final int d = x[0].length();
        final double p = 2 * Math.log(d);

        Vec theta = new DenseVector(d);
        double theta_bias = 0;
        double lossScore = 0;
        w = new DenseVector(d);

        Random rand = RandomUtil.getRandom();

        for (int t = 0; t < epochs; t++) {
            int i = rand.nextInt(m);

            lossScore = loss.deriv(w.dot(x[i]) + bias, y[i]);

            theta.mutableSubtract(eta * lossScore, x[i]);
            theta_bias -= eta * lossScore;

            for (IndexValue iv : theta) {
                int j = iv.getIndex();
                double theta_j = iv.getValue();// theta.get(j);
                theta.set(j, signum(theta_j) * max(0, abs(theta_j) - eta * lambda));
            }
            theta_bias = signum(theta_bias) * max(0, abs(theta_bias) - eta * lambda);

            final double thetaNorm = theta.pNorm(p);
            if (thetaNorm > 0) {
                // w = f^-1(theta)
                final double logThetaNorm = log(thetaNorm);
                for (int j = 0; j < w.length(); j++) {
                    double theta_j = theta.get(j);
                    w.set(j, signum(theta_j) * exp((p - 1) * log(abs(theta_j)) - (p - 2) * logThetaNorm));
                }
                bias = signum(theta_bias) * exp((p - 1) * log(abs(theta_bias)) - (p - 2) * logThetaNorm);
            } else {
                theta.zeroOut();
                theta_bias = 0;
                w.zeroOut();
                bias = 0;
            }
        }
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public SMIDAS clone() {
        SMIDAS clone = new SMIDAS(eta, epochs, lambda, loss, reScale);
        if (this.w != null)
            clone.w = this.w.clone();
        clone.bias = this.bias;
        clone.minScaled = this.minScaled;
        clone.maxScaled = this.maxScaled;
        if (this.obvMin != null)
            clone.obvMin = Arrays.copyOf(this.obvMin, this.obvMin.length);
        if (this.obvMax != null)
            clone.obvMax = Arrays.copyOf(this.obvMax, this.obvMax.length);
        return clone;
    }

    private Vec[] setUpVecs(DataSet dataSet) {
        obvMin = new double[dataSet.getNumNumericalVars()];
        Arrays.fill(obvMin, Double.POSITIVE_INFINITY);
        obvMax = new double[dataSet.getNumNumericalVars()];
        Arrays.fill(obvMax, Double.NEGATIVE_INFINITY);
        Vec[] x = new Vec[dataSet.size()];
        for (int i = 0; i < dataSet.size(); i++) {
            x[i] = dataSet.getDataPoint(i).getNumericalValues();

            for (IndexValue iv : x[i]) {
                int j = iv.getIndex();
                double v = iv.getValue();
                obvMin[j] = Math.min(obvMin[j], v);
                obvMax[j] = Math.max(obvMax[j], v);
            }
        }

        if (x[0].isSparse())// Assume implicit min zeros from sparsity
            for (int i = 0; i < obvMin.length; i++)
                obvMin[i] = Math.min(obvMin[i], 0);

        if (!reScale) {
            for (double min : obvMin)
                if (min < -1)
                    throw new FailedToFitException("Values must be in the range [-1,1], " + min + " violation encountered");
            for (double max : obvMax)
                if (max > 1)
                    throw new FailedToFitException("Values must be in the range [-1,1], " + max + " violation encountered");

        }
        return x;
    }

}
